#=====================================
#========Human Category Chains========
#=====================================

#====================
#====Front Matter====
#====================

library(rstan)
library(rethinking)

#========================
#====Custom Functions====
#========================

# Expected number of consecutive correct responses in a serial (SimChain-style)
# list, given per-position success probabilities.
#
# p : numeric vector of length P (a single list) or a matrix with one row per
#     list and one column per position. The formula multiplies p[, 1..i], i.e.
#     p[r, i] is treated as the chance of a correct response at position i once
#     positions 1..(i-1) were correct.
# Returns a numeric vector with one entry per row of p:
#   sum_{i=1}^{P-1} i * p[,1]*...*p[,i] * (1 - p[,i+1])  +  P * prod(p[r, ])
simchain_density <- function(p){
	if (is.null(dim(p))) {
		p <- matrix(p, nrow = 1, ncol = length(p))
	}
	n_pos <- ncol(p)
	expected <- rep(0, nrow(p))
	reach <- rep(1, nrow(p))   # running product p[,1]*...*p[,i]
	# seq_len() so a one-position list skips the loop instead of iterating over c(1, 0)
	for (i in seq_len(n_pos - 1)) {
		reach <- reach * p[, i]
		# exactly i correct: positions 1..i right, position i+1 wrong
		expected <- expected + i * reach * (1 - p[, i + 1])
	}
	# all P positions correct
	expected + n_pos * apply(p, 1, prod)
}

# Learning curve for one list position:
#   flr + (0.999 - flr) * inv_logit(x * slope + shift)^twist
# i.e. a logistic raised to the power `twist`, rising from the floor `flr`
# toward 0.999 (the same form as the per-position success probability in
# human_cat_chain_model_string). Entries whose floor is >= 0.999 are set to 0.999.
#
# x                   : numeric predictor (e.g. trial numbers).
# flr                 : floor probability; a scalar or a vector the same length as x.
# shift, slope, twist : logistic offset, slope and exponent.
# inv_logit() is not defined in this file; presumably it comes from the
# rethinking package loaded above.
asymm_logit_learning <- function(x, flr, shift, slope, twist) {
	# log(inv_logit(.)^twist), computed as twist * log(inv_logit(.))
	outcome <- log(inv_logit(x*slope + shift))*twist
	below <- flr < 0.999
	# Subset flr with the same mask as outcome so that vector floors stay aligned
	# element-by-element (a scalar floor is simply recycled).
	outcome[below] <- exp(outcome[below] + log(0.999 - flr[below])) + flr[below]
	outcome[!below] <- 0.999
	return(outcome)
}

#===================
#====Stan Models====
#===================

# Stan program for trial-by-trial progress in the category-chain task.
# For subject i and list position j < P the success probability is
#   inv_logit(f) + (0.999 - inv_logit(f)) * inv_logit(trial*s + m)^v,
# the last position P uses only inv_logit(f), and dsim_log() scores how many
# responses in a row were correct (progress) on each trial.
# Population means (shift, slope, twist and floor blocks of params_vec) are
# non-centered draws from a multivariate normal with covariance
# diag(tau) * Omega * diag(tau), where Omega = L_Omega * L_Omega'.
# NOTE: the Stan code is inside an R double-quoted string, so it must not
# contain double quotes.
human_cat_chain_model_string <- "
functions {
	real dsim_log(int[] y, int N, int P, real[,] prob) {
		// Computes the log-likelihood for a process-of-elimination serial task (such as SimChain)
		vector[N] q;
		for (i in 1:N) {
			q[i] <- 0;
			for (j in 1:y[i]) {
				q[i] <- q[i] + log(prob[i,j]);        // Each correct responses in a sequence 1:j during trial i corresponds to success with a probability of p[i,j]
			}
			if (y[i] < P) {
				q[i] <- q[i] + log(1-prob[i,y[i]+1]); // If an error is made, the log-likelihood must be incremented by a failure with a probability of (1-p[i,y[i]+1])
			}
		}
		return sum(q);
	}
}
data{
	int<lower=1> N;              // Number of observations
	int<lower=1> S;              // Number of subjects
	int<lower=1> P;              // Number of list elements
	int<lower=1> Slist[N];       // List of subject indices
	real trial[N];               // List of trial numbers
	int<lower=0> progress[N];    // List of trial progressions
}
transformed data {
	int Pmn;
	int K;
	Pmn <- P-1;      // Pmn is a placeholder value for 'all but the last response'
	K <- 4*Pmn + 1;  // K is used for vectors containing the full list of population parameters
}
parameters{
	real m_vec[S*Pmn];                     // z-scores for subject-level shift parameter for Stims 1:Pmn
	real s_vec[S*Pmn];                     // z-scores for subject-level slope parameter for Stims 1:Pmn
	real v_vec[S*Pmn];                     // z-scores for subject-level twist parameter for Stims 1:Pmn
	real f_vec[S*P];                       // z-scores for subject-level floor parameter for Stims 1:P
	vector[K] z;                           // Standard-normal innovations for the population-level parameters (non-centered)
	cholesky_factor_corr[K] L_Omega;       // Cholesky-factored correlation matrix for parameter covariance matrix
	vector<lower=0>[K] tau;                // Covariance scale factor for population-level parameters
	vector[K] bigmu;                       // Means for population-level parameters
	vector<lower=0>[Pmn] sigma_m;          // Dispersion term for subject-level shift parameters
	vector<lower=0, upper=4>[Pmn] sigma_s; // Dispersion term for subject-level slope parameters; heavily constrained to prevent divergence
	vector<lower=0, upper=4>[Pmn] sigma_v; // Dispersion term for subject-level twist parameters; heavily constrained to prevent divergence
	vector<lower=0, upper=6>[P]   sigma_f; // Dispersion term for subject-level floor parameters; heavily constrained to prevent divergence
}
transformed parameters {
	vector[K] params_vec;         // Population-level parameter vector
	matrix[K,K] L_Sigma;          // Cholesky factor of the population-level covariance matrix
	real m[S,Pmn];                // Subject-level shift for Stims 1:Pmn
	real<lower=0> s[S,Pmn];       // Subject-level slope for Stims 1:Pmn
	real<lower=0> v[S,Pmn];       // Subject-level twist for Stims 1:Pmn
	real f[S,P];                  // Subject-level floor for Stims 1:P
	real mu_m[Pmn];               // Population central tendency for shift for Stims 1:Pmn
	real mu_s[Pmn];               // Population central tendency for slope for Stims 1:Pmn
	real mu_v[Pmn];               // Population central tendency for twist for Stims 1:Pmn
	real mu_f[P];                 // Population central tendency for floor for Stims 1:P

	L_Sigma <- diag_pre_multiply(tau, L_Omega);  // Cholesky factor of diag(tau) * Omega * diag(tau)
	params_vec <- bigmu + L_Sigma * z;           // Non-centered draw with covariance L_Sigma * L_Sigma'

	for ( j in 1:Pmn ) {
		mu_m[j] <- params_vec[j];
		mu_s[j] <- params_vec[j+Pmn];
		mu_v[j] <- params_vec[j+2*Pmn];
	}
	for ( j in 1:P ) {
		mu_f[j] <- params_vec[j+3*Pmn];
	}

	for ( i in 1:S ) {
		for ( j in 1:Pmn ) {
			m[i,j] <-     mu_m[j] + sigma_m[j] * m_vec[i+(j-1)*S];
			s[i,j] <- exp(mu_s[j] + sigma_s[j] * s_vec[i+(j-1)*S]);
			v[i,j] <- exp(mu_v[j] + sigma_v[j] * v_vec[i+(j-1)*S]);
		}
		for ( j in 1:P ) {
			f[i,j] <-     mu_f[j] + sigma_f[j] * f_vec[i+(j-1)*S];
		}
	}

}
model{
	// implies params_vec ~ multi_normal_cholesky( bigmu , L_Sigma ), i.e. covariance diag(tau)*Omega*diag(tau)
	// mu_m, mu_s, mu_v, and mu_f are then extracted from params_vec for convenience
	tau   ~ normal( 0 , 1 );
	z     ~ normal( 0 , 1 );
	bigmu ~ normal( 0 , 2 );
	L_Omega ~ lkj_corr_cholesky(2);

	// implies m ~ normal( mu_m , sigma_m )
	sigma_m ~ normal( 0 , 1 );
	m_vec   ~ normal( 0 , 1 );

	// implies s ~ exp(normal( mu_s , sigma_s ))
	sigma_s ~ normal( 0 , 1 );
	s_vec   ~ normal( 0 , 1 );

	// implies v ~ exp(normal( mu_v , sigma_v ))
	sigma_v ~ normal( 0 , 1 );
	v_vec   ~ normal( 0 , 1 );

	// implies f ~ normal( mu_f , sigma_f )
	sigma_f ~ normal( 0 , 1 );
	f_vec   ~ normal( 0 , 1 );

	{
		real p[N,P];
		real bot;
		for ( i in 1:N ) {
			p[i,P] <- inv_logit(f[Slist[i],P]);
			for ( j in 1:Pmn ) {
				p[i,j] <- inv_logit(trial[i]*s[Slist[i],j] + m[Slist[i],j])^v[Slist[i],j];
				bot <- inv_logit(f[Slist[i],j]);
				p[i,j] <- bot + (0.999-bot)*p[i,j];
			}
		}
		increment_log_prob(dsim_log(progress,N,P,p));
	}
}
"

# Stan program for log reaction times. For subject i and response position j:
#   react ~ normal(k_rt[i,j] + b_rt[i,j] * trial, s_rt[i,j]),
# with k_rt and b_rt built non-centered from per-position population means
# (mu_k, mu_b) and dispersions (sigma_k, sigma_b).
# Slist and position index the S x P arrays, so they are bounded to 1..S and
# 1..P; out-of-range data is rejected when the data are validated.
human_cat_chain_RT_string <- "
data{
	int<lower=0> N;                      // Number of observations
	int<lower=0> S;                      // Number of subjects
	int<lower=0> P;                      // Number of list stimuli
	int<lower=1, upper=S> Slist[N];      // List of subject indices (row of k_rt, b_rt, s_rt)
	int trial[N];                        // List of trial numbers
	int<lower=1, upper=P> position[N];   // List of response positions (column of k_rt, b_rt, s_rt)
	real react[N];                       // List of log reaction times
}
parameters{
	real k_vec[S*P];            // z-scores for subject-level intercepts
	real b_vec[S*P];            // z-scores for subject-level slopes
	real<lower=0> s_rt[S,P];    // sigma log reaction times
	real mu_k[P];               // Population central tendency for k_rt
	real mu_b[P];               // Population central tendency for b_rt
	real<lower=0> sigma_k[P];   // Population dispersion for k_rt
	real<lower=0> sigma_b[P];   // Population dispersion for b_rt
}
transformed parameters {
	real k_rt[S,P];             // intercept of log reaction times
	real b_rt[S,P];             // slope log reaction times
	for ( i in 1:S ) {
		for ( j in 1:P ) {
			k_rt[i,j] <- mu_k[j] + sigma_k[j] * k_vec[i+(j-1)*S];
			b_rt[i,j] <- mu_b[j] + sigma_b[j] * b_vec[i+(j-1)*S];
		}
	}
}
model{
	mu_k ~ normal( 0 , 20 );
	mu_b ~ normal( 0 , 20 );
	sigma_k ~ normal( 0 , 5 );
	sigma_b ~ normal( 0 , 5 );
	k_vec ~ normal( 0 , 1 );
	b_vec ~ normal( 0 , 1 );
	for ( i in 1:S ) {
		for ( j in 1:P ) {
			s_rt[i,j] ~ exponential( 1 );
		}
	}
	for ( i in 1:N ) {
		increment_log_prob(normal_log(react[i], k_rt[Slist[i],position[i]] + b_rt[Slist[i],position[i]]*trial[i], s_rt[Slist[i],position[i]] ));
	}
}
"

#========================================
#====Load Data (reads HumanCategoryChain.csv from the working directory)====
#========================================

HCC <- read.csv("HumanCategoryChain.csv")

# Row indices of `dat` belonging to experimental phase `phase`
# (which() drops rows whose Phase is NA).
phase_rows <- function(dat, phase) {
	which(dat$Phase == phase)
}

# Map raw subject IDs onto 1..S in sorted-ID order, because the Stan models use
# Slist directly as an array index. When the IDs are already 1..S this returns
# them unchanged.
subject_index <- function(ids) {
	match(ids, sort(unique(ids)))
}

# Data list for human_cat_chain_model_string for one phase.
# trial is Trial + 1 (presumably turning 0-based trial numbers into 1-based;
# confirm). T is not declared in the Stan data block.
make_choice_data <- function(dat, phase, n_pos = 4) {
	rows <- phase_rows(dat, phase)
	list(
		N = length(rows),
		S = length(unique(dat$Subject[rows])),
		P = n_pos,
		T = max(dat$Trial[rows]) + 1,
		Slist = subject_index(dat$Subject[rows]),
		trial = dat$Trial[rows] + 1,
		progress = dat$Progress[rows]
	)
}

# Long-format data list for human_cat_chain_RT_string for one phase: one entry
# per (trial, response position), taken from columns React1..React<n_pos>.
# Missing reaction times are dropped, the rest are log-transformed, and
# trial_offset is subtracted from Trial.
make_rt_data <- function(dat, phase, trial_offset, n_pos = 4) {
	rows <- phase_rows(dat, phase)
	n_rows <- length(rows)
	# React1 for all rows, then React2 for all rows, ... (position-major order)
	react <- unlist(lapply(seq_len(n_pos), function(k) dat[[paste0("React", k)]][rows]))
	keep <- !is.na(react)
	list(
		N = sum(keep),
		S = length(unique(dat$Subject[rows])),
		P = n_pos,
		Slist = rep(subject_index(dat$Subject[rows]), n_pos)[keep],
		trial = rep(dat$Trial[rows] - trial_offset, n_pos)[keep],
		position = rep(seq_len(n_pos), each = n_rows)[keep],
		react = log(react[keep])
	)
}

data_ph <- make_choice_data(HCC, phase = 1)
data_pa <- make_choice_data(HCC, phase = 2)

# Phase-specific trial offsets as in the original analysis (60 for phase 1,
# 100 for phase 2); presumably these center the trial covariate - confirm.
rt_data_ph <- make_rt_data(HCC, phase = 1, trial_offset = 60)
rt_data_pa <- make_rt_data(HCC, phase = 2, trial_offset = 100)

#==================
#====Deploy Stan===
#==================

# Initial values for the single chain of the choice model, sized from the
# photo-phase data (data_ph). Built in local() so the helper names do not leak.
h_prior <- local({
	S <- data_ph$S
	P <- data_ph$P
	K <- 4*(P-1) + 1   # must equal K in the Stan transformed data block
	list(
		list(
			m_vec = rep(0,S*(P-1)),
			s_vec = rep(0,S*(P-1)),
			v_vec = rep(0,S*(P-1)),
			f_vec = rep(0,S*P),
			z = rep(0,K),
			# shift, slope and twist means start at 0; the last four entries are the
			# logit-scale floor means mu_f (assumes P == 4).
			bigmu = c(rep(0,3*(P-1)),-2,-1.3,0,4.6),
			tau = rep(1,K),
			# cholesky_factor_corr expects a LOWER-triangular factor, but chol()
			# returns the upper one, so transpose. Target correlation: 1 on the
			# diagonal, 0.1 off the diagonal.
			L_Omega = t(chol(matrix(0.1,nrow=K,ncol=K) + diag(0.9,nrow=K,ncol=K)))
			# (the former `sigma` entry matched no model parameter and was dropped)
		)
	)
})

# Initial values for the single chain of the reaction-time model, sized from
# data_ph. The names match the RT model's parameters; k_vec = 0 starts
# k_rt at mu_k and b_vec = 0 starts b_rt at mu_b.
h_rt_prior <- local({
	S <- data_ph$S
	P <- data_ph$P
	list(
		list(
			mu_k = rep(0,P),
			mu_b = rep(0,P),
			sigma_k = rep(1,P),
			sigma_b = rep(1,P),
			k_vec = rep(0,S*P),
			b_vec = rep(0,S*P),
			s_rt = matrix(1,nrow=S,ncol=P)
		)
	)
})

#==Warning: These are expected to take over 24 hours==
samp_w <- 1000   # warmup iterations per chain
samp_m <- 4000   # post-warmup iterations per chain

# Compile each Stan program once; the photo- and paint-phase fits reuse the
# compiled model instead of recompiling the same code in every stan() call.
human_cat_chain_model <- stan_model(model_code=human_cat_chain_model_string)
human_cat_chain_RT_model <- stan_model(model_code=human_cat_chain_RT_string)

# Draw one chain from the compiled `model`, print its summary and show a trace
# plot for `pars`. Returns the stanfit object.
# NOTE(review): h_prior / h_rt_prior are sized from data_ph$S but are reused for
# the paint phase; this only works if both phases have the same subject count.
fit_and_report <- function(model, data, init, pars, warmup = samp_w, iter = samp_m + samp_w) {
	fit <- sampling(model, data=data, iter=iter, warmup=warmup, chains=1, init=init, control=list(max_treedepth = 12))
	print(fit, pars=pars)
	print(stan_trace(fit, pars=pars))   # explicit print: not auto-printed inside a function
	fit
}

hum_cat_chain_photo_sim <- fit_and_report(human_cat_chain_model, data_ph, h_prior, "params_vec")
h_ph_prms <- rstan::extract(hum_cat_chain_photo_sim, permuted=TRUE)

hum_cat_chain_paint_sim <- fit_and_report(human_cat_chain_model, data_pa, h_prior, "params_vec")
h_pa_prms <- rstan::extract(hum_cat_chain_paint_sim, permuted=TRUE)

hum_cat_chain_photo_rt_sim <- fit_and_report(human_cat_chain_RT_model, rt_data_ph, h_rt_prior, c("mu_b","mu_k"))
h_ph_rt_prms <- rstan::extract(hum_cat_chain_photo_rt_sim, permuted=TRUE)

hum_cat_chain_paint_rt_sim <- fit_and_report(human_cat_chain_RT_model, rt_data_pa, h_rt_prior, c("mu_b","mu_k"))
h_pa_rt_prms <- rstan::extract(hum_cat_chain_paint_rt_sim, permuted=TRUE)


